{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "| [03_language_model/02_Transformer语言模型.ipynb](https://github.com/shibing624/nlp-tutorial/blob/main/03_language_model/02_Transformer语言模型.ipynb)  | Train a Transformer language model from scratch  |[Open In Colab](https://colab.research.google.com/github/shibing624/nlp-tutorial/blob/main/03_language_model/02_Transformer语言模型.ipynb) |\n",
    "\n",
    "# Transformer Language Model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This section trains a sequence-to-sequence model using PyTorch's \n",
    "[`nn.Transformer`](https://pytorch.org/docs/master/nn.html?highlight=nn%20transformer#torch.nn.Transformer) module.\n",
    "\n",
    "PyTorch 1.2 includes a Transformer implementation based on the paper [Attention Is All You Need](https://arxiv.org/pdf/1706.03762.pdf). The `nn.Transformer` module relies on the attention mechanism to model the dependencies between input and output text.\n",
    "\n",
    "\n",
    "Architecture of the Transformer model:\n",
    "![](../docs/transformer_architecture.jpg)\n"
   ]
  },
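  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Define the model\n",
    "\n",
    "We train a language model built on `nn.TransformerEncoder`.\n",
    "\n",
    "The language modeling task is to assign a probability to each word that may follow a given sequence of words, indicating how likely that word is to appear next.\n",
    "\n",
    "The tokens are first passed through an embedding layer and then through a positional encoding layer that accounts for word order. `nn.TransformerEncoder` consists of multiple [`nn.TransformerEncoderLayer`](https://pytorch.org/docs/master/nn.html?highlight=transformerencoderlayer#torch.nn.TransformerEncoderLayer) layers. For language modeling, tokens at future positions must be masked, so a square attention mask is passed along with the input and each position can attend only to itself and to earlier positions. To turn the encoder output into word predictions, the output of `nn.TransformerEncoder` is sent to a final linear layer that produces one score per vocabulary entry; the log-softmax is not applied inside the model but by `nn.CrossEntropyLoss` during training."
   ]
  },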
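  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a small illustration of the masking described above, this is the additive mask that the `_generate_square_subsequent_mask` method in the model code would produce for a sequence of length 3 (the length 3 is chosen only for illustration). Row $i$ lists what position $i$ may attend to: $0$ leaves an attention score unchanged, and $-\\infty$ drives its weight to zero after the softmax.\n",
    "\n",
    "$$\n",
    "M = \\begin{bmatrix} 0 & -\\infty & -\\infty \\\\ 0 & 0 & -\\infty \\\\ 0 & 0 & 0 \\end{bmatrix}\n",
    "$$\n",
    "\n",
    "The first position sees only itself, while the last position sees the whole prefix."
   ]
  },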
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "\n",
    "class TransformerModel(nn.Module):\n",
    "\n",
    "    def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):\n",
    "        super(TransformerModel, self).__init__()\n",
    "        from torch.nn import TransformerEncoder, TransformerEncoderLayer\n",
    "        self.model_type = 'Transformer'\n",
    "        self.src_mask = None\n",
    "        self.pos_encoder = PositionalEncoding(ninp, dropout)\n",
    "        encoder_layers = TransformerEncoderLayer(ninp, nhead, nhid, dropout)\n",
    "        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)\n",
    "        self.encoder = nn.Embedding(ntoken, ninp)\n",
    "        self.ninp = ninp\n",
    "        self.decoder = nn.Linear(ninp, ntoken)\n",
    "\n",
    "        self.init_weights()\n",
    "\n",
    "    def _generate_square_subsequent_mask(self, sz):\n",
    "        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)\n",
    "        mask = mask.float().masked_fill(mask == 0, float(\n",
    "            '-inf')).masked_fill(mask == 1, float(0.0))\n",
    "        return mask\n",
    "\n",
    "    def init_weights(self):\n",
    "        initrange = 0.1\n",
    "        self.encoder.weight.data.uniform_(-initrange, initrange)\n",
    "        self.decoder.bias.data.zero_()\n",
    "        self.decoder.weight.data.uniform_(-initrange, initrange)\n",
    "\n",
    "    def forward(self, src):\n",
    "        if self.src_mask is None or self.src_mask.size(0) != len(src):\n",
    "            device = src.device\n",
    "            mask = self._generate_square_subsequent_mask(len(src)).to(device)\n",
    "            self.src_mask = mask\n",
    "\n",
    "        src = self.encoder(src) * math.sqrt(self.ninp)\n",
    "        src = self.pos_encoder(src)\n",
    "        output = self.transformer_encoder(src, self.src_mask)\n",
    "        output = self.decoder(output)\n",
    "        return output"
   ]
  },
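  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `PositionalEncoding` module injects information about the relative or absolute position of the tokens in the sequence. The positional encodings have the same dimension as the embeddings so that the two can be summed; here they are computed from sine and cosine functions of different frequencies."
   ]
  },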
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class PositionalEncoding(nn.Module):\n",
    "\n",
    "    def __init__(self, d_model, dropout=0.1, max_len=5000):\n",
    "        super(PositionalEncoding, self).__init__()\n",
    "        self.dropout = nn.Dropout(p=dropout)\n",
    "\n",
    "        pe = torch.zeros(max_len, d_model)\n",
    "        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)\n",
    "        div_term = torch.exp(torch.arange(\n",
    "            0, d_model, 2).float() * (-math.log(10000.0) / d_model))\n",
    "        pe[:, 0::2] = torch.sin(position * div_term)\n",
    "        pe[:, 1::2] = torch.cos(position * div_term)\n",
    "        pe = pe.unsqueeze(0).transpose(0, 1)\n",
    "        self.register_buffer('pe', pe)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = x + self.pe[:x.size(0), :]\n",
    "        return self.dropout(x)"
   ]
  },
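  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Written out, the encoding that the `PositionalEncoding` class above computes appears to be the sinusoidal scheme from the Attention Is All You Need paper. Here $pos$ is the token position, $i$ indexes pairs of embedding dimensions, and $d_{\\text{model}}$ is the embedding dimension (the symbols follow the paper, not the notebook):\n",
    "\n",
    "$$\n",
    "PE_{(pos,\\,2i)} = \\sin\\left(\\frac{pos}{10000^{2i/d_{\\text{model}}}}\\right), \\qquad PE_{(pos,\\,2i+1)} = \\cos\\left(\\frac{pos}{10000^{2i/d_{\\text{model}}}}\\right)\n",
    "$$\n",
    "\n",
    "The code evaluates the denominator in log space: `div_term` holds $\\exp\\left(-2i \\ln(10000) / d_{\\text{model}}\\right) = 10000^{-2i/d_{\\text{model}}}$, so `position * div_term` is the argument of the sine and cosine."
   ]
  },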
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load the data\n",
    "\n",
    "Training uses the Wikitext-2 dataset from torchtext. The vocab is built from the train split. The `batchify()` function arranges the dataset into columns and trims off any tokens that remain after the data has been divided into `batch_size` columns.\n",
    "\n",
    "For example, with the alphabet as the sequence (total length 26) and a batch size of 4, the alphabet is divided into 4 sequences of length 6, and the leftover letters Y and Z are dropped:\n",
    "\n",
    "\\begin{align}\\begin{bmatrix}\n",
    "  \\text{A} & \\text{B} & \\text{C} & \\ldots & \\text{X} & \\text{Y} & \\text{Z}\n",
    "  \\end{bmatrix}\n",
    "  \\Rightarrow\n",
    "  \\begin{bmatrix}\n",
    "  \\begin{bmatrix}\\text{A} \\\\ \\text{B} \\\\ \\text{C} \\\\ \\text{D} \\\\ \\text{E} \\\\ \\text{F}\\end{bmatrix} &\n",
    "  \\begin{bmatrix}\\text{G} \\\\ \\text{H} \\\\ \\text{I} \\\\ \\text{J} \\\\ \\text{K} \\\\ \\text{L}\\end{bmatrix} &\n",
    "  \\begin{bmatrix}\\text{M} \\\\ \\text{N} \\\\ \\text{O} \\\\ \\text{P} \\\\ \\text{Q} \\\\ \\text{R}\\end{bmatrix} &\n",
    "  \\begin{bmatrix}\\text{S} \\\\ \\text{T} \\\\ \\text{U} \\\\ \\text{V} \\\\ \\text{W} \\\\ \\text{X}\\end{bmatrix}\n",
    "  \\end{bmatrix}\\end{align}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torchtext.legacy.data.field.Field at 0x7fc1de1219d0>"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import os\n",
    "import torchtext\n",
    "from torchtext.data.utils import get_tokenizer\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "TEXT = torchtext.legacy.data.Field(init_token='<sos>',\n",
    "                                   eos_token='<eos>',\n",
    "                                   lower=True)\n",
    "\n",
    "train_txt, val_txt, test_txt = torchtext.legacy.datasets.language_modeling.WikiText2.splits(TEXT)\n",
    "TEXT.build_vocab(train_txt)\n",
    "\n",
    "TEXT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2088628"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(train_txt.examples[0].text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([104431, 20])\n",
      "torch.Size([21764, 10])\n"
     ]
    }
   ],
   "source": [
    "def batchify(data, bsz):\n",
    "    data = TEXT.numericalize([data.examples[0].text])\n",
    "    # Divide the dataset into bsz parts.\n",
    "    nbatch = data.size(0) // bsz\n",
    "    # Trim off any extra elements that wouldn't cleanly fit (remainders).\n",
    "    data = data.narrow(0, 0, nbatch * bsz)\n",
    "    # Evenly divide the data across the bsz batches.\n",
    "    data = data.view(bsz, -1).t().contiguous()\n",
    "    return data.to(device)\n",
    "\n",
    "\n",
    "batch_size = 20\n",
    "eval_batch_size = 10\n",
    "train_data = batchify(train_txt, batch_size)\n",
    "val_data = batchify(val_txt, eval_batch_size)\n",
    "test_data = batchify(test_txt, eval_batch_size)\n",
    "\n",
    "print(train_data.shape)\n",
    "print(val_data.shape)"
   ]
  },
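  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The training shape printed above can be checked by hand. A sketch of the arithmetic, assuming the training split holds the 2088628 tokens reported earlier and `batch_size = 20`:\n",
    "\n",
    "$$\n",
    "\\left\\lfloor \\frac{2088628}{20} \\right\\rfloor = 104431, \\qquad 104431 \\times 20 = 2088620, \\qquad 2088628 - 2088620 = 8\n",
    "$$\n",
    "\n",
    "So `batchify()` trims the 8 trailing tokens and reshapes the rest into 104431 rows and 20 columns, which matches `torch.Size([104431, 20])`."
   ]
  },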
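  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Generate input and target sequences\n",
    "\n",
    "`get_batch()` returns an input chunk of at most `bptt` rows and, as the target, the same chunk shifted forward by one token and flattened.\n"
   ]
  },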
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "bptt = 35\n",
    "\n",
    "\n",
    "def get_batch(source, i):\n",
    "    seq_len = min(bptt, len(source) - 1 - i)\n",
    "    data = source[i:i+seq_len]\n",
    "    target = source[i+1:i+1+seq_len].view(-1)\n",
    "    return data, target"
   ]
  },
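  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The number of chunks that `get_batch` yields per pass over the training data can be worked out from the shapes above. A sketch of the arithmetic, assuming the 104431-row `train_data` shown earlier and a loop that advances `i` in strides of `bptt = 35` while `i < 104430` (one row is held back because the target is shifted by one):\n",
    "\n",
    "$$\n",
    "\\left\\lfloor \\frac{104430}{35} \\right\\rfloor = 2983, \\qquad 2983 \\times 35 = 104405, \\qquad 104430 - 104405 = 25\n",
    "$$\n",
    "\n",
    "That gives 2983 full chunks of 35 rows followed by one final chunk of 25 rows, which is where `seq_len = min(bptt, len(source) - 1 - i)` shortens the last slice."
   ]
  },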
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Try out the model\n",
    "\n",
    "Set the hyperparameters:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "ntokens = len(TEXT.vocab.stoi)  # the size of vocabulary\n",
    "emsize = 200  # embedding dimension\n",
    "nhid = 200  # the dimension of the feedforward network model in nn.TransformerEncoder\n",
    "nlayers = 2  # the number of nn.TransformerEncoderLayer in nn.TransformerEncoder\n",
    "nhead = 2  # the number of heads in the multiheadattention models\n",
    "dropout = 0.2  # the dropout value\n",
    "model = TransformerModel(ntokens, emsize, nhead, nhid,\n",
    "                         nlayers, dropout).to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "lr = 5.0  # learning rate\n",
    "optimizer = torch.optim.SGD(model.parameters(), lr=lr)\n",
    "# Multiply the learning rate by 0.95 after every epoch\n",
    "scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.95)\n",
    "\n",
    "\n",
    "def train():\n",
    "    # Relies on the global `epoch` set by the training loop for logging\n",
    "    model.train()  # Turn on the train mode\n",
    "    total_loss = 0.\n",
    "    start_time = time.time()\n",
    "    ntokens = len(TEXT.vocab.stoi)\n",
    "    for batch, i in enumerate(range(0, train_data.size(0) - 1, bptt)):\n",
    "        data, targets = get_batch(train_data, i)\n",
    "        optimizer.zero_grad()\n",
    "        output = model(data)\n",
    "        loss = criterion(output.view(-1, ntokens), targets)\n",
    "        loss.backward()\n",
    "        # Clip the gradient norm to 0.5 to keep updates stable\n",
    "        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)\n",
    "        optimizer.step()\n",
    "\n",
    "        total_loss += loss.item()\n",
    "        log_interval = 200\n",
    "        if batch % log_interval == 0 and batch > 0:\n",
    "            cur_loss = total_loss / log_interval\n",
    "            elapsed = time.time() - start_time\n",
    "            print('| epoch {:3d} | {:5d}/{:5d} batches | '\n",
    "                  'lr {:02.2f} | ms/batch {:5.2f} | '\n",
    "                  'loss {:5.2f} | ppl {:8.2f}'.format(\n",
    "                      epoch, batch, len(\n",
    "                          train_data) // bptt, scheduler.get_last_lr()[0],\n",
    "                      elapsed * 1000 / log_interval,\n",
    "                      cur_loss, math.exp(cur_loss)))\n",
    "            total_loss = 0\n",
    "            start_time = time.time()\n",
    "\n",
    "\n",
    "def evaluate(eval_model, data_source):\n",
    "    eval_model.eval()  # Turn on the evaluation mode\n",
    "    total_loss = 0.\n",
    "    ntokens = len(TEXT.vocab.stoi)\n",
    "    with torch.no_grad():\n",
    "        for i in range(0, data_source.size(0) - 1, bptt):\n",
    "            data, targets = get_batch(data_source, i)\n",
    "            output = eval_model(data)\n",
    "            output_flat = output.view(-1, ntokens)\n",
    "            # Weight each chunk's mean loss by its length\n",
    "            total_loss += len(data) * criterion(output_flat, targets).item()\n",
    "    return total_loss / (len(data_source) - 1)"
   ]
  },
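  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `ppl` value logged by `train()` is the exponential of the average cross-entropy loss. As a sketch in symbols that are not part of the notebook, let $N$ be the number of target tokens and $p(w_t \\mid w_{<t})$ the probability the model assigns to the correct next token:\n",
    "\n",
    "$$\n",
    "L = -\\frac{1}{N} \\sum_{t=1}^{N} \\ln p(w_t \\mid w_{<t}), \\qquad \\text{PPL} = e^{L}\n",
    "$$\n",
    "\n",
    "`evaluate()` weights each chunk's mean loss by `len(data)` and divides by `len(data_source) - 1`, so a shorter final chunk contributes in proportion to its length."
   ]
  },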
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Save the model whenever the validation loss reaches a new best, and adjust the learning rate at the end of each epoch."
   ]
  },
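  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because the `StepLR` scheduler is stepped once per epoch with a step size of 1 and `gamma=0.95`, the learning rate decays geometrically. Assuming the initial rate `lr = 5.0` set earlier, the rate used during epoch $e$ (counting from 1) is:\n",
    "\n",
    "$$\n",
    "lr_e = 5.0 \\times 0.95^{\\,e-1}\n",
    "$$\n",
    "\n",
    "Epoch 1 therefore trains at 5.0, epoch 2 at 4.75, and epoch 10 at about 3.15."
   ]
  },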
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "| epoch   1 |   200/ 2983 batches | lr 5.00 | ms/batch 491.00 | loss  8.02 | ppl  3046.57\n",
      "| epoch   1 |   400/ 2983 batches | lr 5.00 | ms/batch 499.58 | loss  6.85 | ppl   941.26\n",
      "| epoch   1 |   600/ 2983 batches | lr 5.00 | ms/batch 500.60 | loss  6.42 | ppl   614.55\n",
      "| epoch   1 |   800/ 2983 batches | lr 5.00 | ms/batch 496.18 | loss  6.29 | ppl   537.29\n",
      "| epoch   1 |  1000/ 2983 batches | lr 5.00 | ms/batch 487.90 | loss  6.18 | ppl   481.80\n",
      "| epoch   1 |  1200/ 2983 batches | lr 5.00 | ms/batch 496.27 | loss  6.15 | ppl   467.33\n",
      "| epoch   1 |  1400/ 2983 batches | lr 5.00 | ms/batch 506.12 | loss  6.09 | ppl   441.26\n",
      "| epoch   1 |  1600/ 2983 batches | lr 5.00 | ms/batch 514.11 | loss  6.10 | ppl   447.08\n",
      "| epoch   1 |  1800/ 2983 batches | lr 5.00 | ms/batch 510.11 | loss  5.99 | ppl   400.76\n"
     ]
    }
   ],
   "source": [
    "best_val_loss = float(\"inf\")\n",
    "epochs = 10  # The number of epochs\n",
    "best_model = None\n",
    "MODEL_PATH = 'transformer_lm.pth'\n",
    "for epoch in range(1, epochs + 1):\n",
    "    epoch_start_time = time.time()\n",
    "    train()\n",
    "    val_loss = evaluate(model, val_data)\n",
    "    print('-' * 89)\n",
    "    print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '\n",
    "          'valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time),\n",
    "                                     val_loss, math.exp(val_loss)))\n",
    "    print('-' * 89)\n",
    "\n",
    "    if val_loss < best_val_loss:\n",
    "        best_val_loss = val_loss\n",
    "        best_model = model\n",
    "        torch.save(best_model.state_dict(), MODEL_PATH)\n",
    "\n",
    "    scheduler.step()\n",
    "\n",
    "best_model.load_state_dict(torch.load(MODEL_PATH))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluate the model with the test dataset\n",
    "\n",
    "Apply the best model to check the result with the test dataset.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_loss = evaluate(best_model, test_data)\n",
    "print('=' * 89)\n",
    "print('| End of training | test loss {:5.2f} | test ppl {:8.2f}'.format(\n",
    "    test_loss, math.exp(test_loss)))\n",
    "print('=' * 89)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.remove('transformer_lm.pth')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "End of this section."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}